-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Pseudo-inverse preconditioner for EqQP. #133
base: main
Are you sure you want to change the base?
Conversation
31cb232
to
55a2b81
Compare
CC @Algue-Rythme, this is pretty similar to your Jacobi preconditioner for OSQP. |
ffefbf1
to
e118328
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PseudoInverse preconditioner is a good idea if you plan on solving a sequence of similar QPs.
I left a few comments.
@@ -77,7 +77,7 @@ def test_qp_eq_only(self): | |||
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) | |||
sol = qp.run(**hyperparams).params | |||
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) | |||
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) | |||
self._check_derivative_Q_c_A_b(qp, Q, c, A, b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch ! Maybe you need self._check_derivative_Q_c_A_b(qp, None, Q, c, A, b)
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the second argument from the self._check_derivative_Q_c_A_b
method since it doesn't use it actually.
|
||
|
||
def row_matvec(block, x): | ||
return sum(jax.tree_util.tree_map(jnp.dot, block, x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If my understanding is correct, here block
is actually a tuple of blocks, and x
a tuple of vectors with the same structure ? So technicaly it is not a row_vector
product since the result is not a scalar as one would expect. Maybe add a docstring and consider renaming the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, let me know :)
[C, D]] | ||
|
||
""" | ||
return jax.tree_util.tree_map( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you are stopping the recursion at depth 1 (i.e you only retrieve [A,B]
and C,D
) I believe it would be more clear to hardcode it instead of using tree_map
. It is a bit overkill here. For example consider writing upper_block = self.blocks[0]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True ; my hope was to leave a stub which could be extended for any sized block matrix, and not just 2x2. I can go either way on this, let me know what you think.
@jax.tree_util.register_pytree_node_class | ||
@dataclass | ||
class BlockLinearOperator: | ||
"""Represents a linear operator defined by blocks over a block pytree. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does the pytree
refers to ? The 2x2 tuple Tuple[Tuple[jnp.array]]
? Or is it something more complicated ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is -- without registering it as a pytree node, I get this error: type <class 'jaxopt._src.linear_operator.BlockLinearOperator'> is not a valid JAX type.
64bb1a1
to
af04a6d
Compare
c78943e
to
5a3ca42
Compare
This allows to precompute a preconditioner, and share it across multiple outer loops, where the inner loop is solving an Equality Constrained QP. This should provide speedups when the parameters of the inner loop QP don't change too much. TODO: modify the implicit diff decorator so that the jvp also uses the preconditioner.
5a3ca42
to
3126417
Compare
@@ -57,7 +57,8 @@ def eq_fun(primal_var, params_eq): | |||
|
|||
# It is required to post_process the output of `idf.make_kkt_optimality_fun` | |||
# to make the signatures of optimality_fun() and run() agree. | |||
def optimality_fun(params, params_obj, params_eq): | |||
# The M argument is needed for using preconditioners. | |||
def optimality_fun(params, params_obj, params_eq, M=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On first sight, I'm rather -1 on introducing a pre-conditioner M
in optimality_fun
since I don't think it plays any role in the optimality conditions (it would not make sense to differentiate with respect to M
). Since run
and optimality_fun
need to have the same signature, this rules out adding M
to run
as well.
I think I would go for something like this instead:
preconditioner = PseudoInversePreconditioner(params_obj, params_eq)
qp = EqualityConstrainedQP(preconditioner=preconditioner)
qp.run(params_obj=params_obj, params_eq=params_eq)
Typically, stuff that doesn't need to be differentiated should go to the constructor.
If you want to differentiate wrt params_eq
or params_obj
, you may need to use
EqualityConstrainedQP(preconditioner=lax.stop_gradient(preconditioner))
instead. Not entirely sure if PseudoInversePreconditioner
should live in JAXopt or in user land.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't play a role in the optimality conditions -- it could play a role in the backwards though, if we manage to pass the argument to the backward solver as well. This would hopefully speed things up, since the forward and backward linear systems share the linear operator.
With the current API, if we're solving the QP as the inner problem of a bi-level problem, then we need to build a new QP solver instance at each step of the outer loop, and pass solve=partial(linearsolver, M=preconditioner) to both solve
and implicit_diff_solve
.
Something which is then unclear to me: does building a new instance of the QP solver necessarily trigger recompilation of the run
method at each iteration ?
This allows to precompute a preconditioner, and share it across multiple
outer loops, where the inner loop is solving an Equality Constrained QP.
This should provide speedups when the parameters of the inner loop QP
don't change too much.
TODO: modify the implicit diff decorator so that the jvp also uses the same preconditioner, since the backward system shares the same linear operator with the forward.